/**
 * @license
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {
  aggregateResponses,
  getResponseStream,
  processStream,
} from "./stream-reader";
import { expect, use } from "chai";
import { restore } from "sinon";
import * as sinonChai from "sinon-chai";
import {
  getChunkedStream,
  getErrorStream,
  getMockResponseStreaming,
} from "../../test-utils/mock-response";
import {
  BlockReason,
  FinishReason,
  GenerateContentResponse,
  HarmCategory,
  HarmProbability,
} from "../../types";
import {
  GoogleGenerativeAIAbortError,
  GoogleGenerativeAIError,
} from "../errors";

use(sinonChai);

describe("getResponseStream", () => {
  afterEach(() => {
    restore();
  });

  /**
   * Drains `stream` to completion and returns every chunk it produced, in
   * order. Rejects with whatever error the underlying reader rejects with.
   */
  async function readAll<T>(stream: ReadableStream<T>): Promise<T[]> {
    const reader = stream.getReader();
    const chunks: T[] = [];
    while (true) {
      const { done, value } = await reader.read();
      if (done) {
        return chunks;
      }
      chunks.push(value);
    }
  }

  /**
   * Drains `stream` and returns the error it rejected with.
   *
   * Asserts that an `Error` was actually thrown. Without this check, an
   * error-path test whose stream unexpectedly completes would pass without
   * checking anything.
   */
  async function readAllExpectingError<T>(
    stream: ReadableStream<T>,
  ): Promise<unknown> {
    let caught: unknown;
    try {
      await readAll(stream);
    } catch (e) {
      caught = e;
    }
    expect(caught, "expected the stream to error").to.be.instanceOf(Error);
    return caught;
  }

  it("two lines", async () => {
    const src = [{ text: "A" }, { text: "B" }];
    const inputStream = getChunkedStream(
      src.map((v) => "data: " + JSON.stringify(v) + "\r\n\r\n").join(""),
    ).pipeThrough(new TextDecoderStream("utf8", { fatal: true }));
    const responses = await readAll(
      getResponseStream<{ text: string }>(inputStream),
    );
    expect(responses).to.deep.equal(src);
  });
  it("throw AbortError", async () => {
    const inputStream = getErrorStream(
      new DOMException("Simulated AbortError", "AbortError"),
    ).pipeThrough(new TextDecoderStream("utf8", { fatal: true }));
    const error = (await readAllExpectingError(
      getResponseStream<{ text: string }>(inputStream),
    )) as GoogleGenerativeAIAbortError;
    expect(error.message).to.include("Request aborted");
  });
  it("throw non AbortError", async () => {
    const inputStream = getErrorStream(
      new DOMException("Simulated Error", "RandomError"),
    ).pipeThrough(new TextDecoderStream("utf8", { fatal: true }));
    const error = (await readAllExpectingError(
      getResponseStream<{ text: string }>(inputStream),
    )) as GoogleGenerativeAIError;
    expect(error.message).to.include("Error reading from the stream");
    // A non-abort failure must not be reported as an abort.
    expect(error.message).to.not.include("Request aborted");
  });
});

describe("processStream", () => {
  afterEach(() => {
    restore();
  });

  /**
   * Iterates every streamed chunk and asserts that each has non-empty text.
   * Also asserts that at least one chunk was produced, so the per-chunk check
   * cannot pass on an empty stream.
   */
  async function expectNonEmptyChunks(
    stream: AsyncIterable<{ text: () => string }>,
  ): Promise<void> {
    let chunkCount = 0;
    for await (const chunk of stream) {
      expect(chunk.text()).to.not.be.empty;
      chunkCount++;
    }
    expect(chunkCount).to.be.greaterThan(0);
  }

  it("streaming response - short", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-success-basic-reply-short.txt",
    );
    const result = processStream(fakeResponse as Response);
    await expectNonEmptyChunks(result.stream);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("Cheyenne");
  });
  it("streaming response - long", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-success-basic-reply-long.txt",
    );
    const result = processStream(fakeResponse as Response);
    await expectNonEmptyChunks(result.stream);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("**Cats:**");
    expect(aggregatedResponse.text()).to.include("to their owners.");
  });
  it("streaming response - long - big chunk", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-success-basic-reply-long.txt",
      1e6,
    );
    const result = processStream(fakeResponse as Response);
    await expectNonEmptyChunks(result.stream);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("**Cats:**");
    expect(aggregatedResponse.text()).to.include("to their owners.");
  });
  it("streaming response - utf8", async () => {
    const fakeResponse = getMockResponseStreaming("streaming-success-utf8.txt");
    const result = processStream(fakeResponse as Response);
    await expectNonEmptyChunks(result.stream);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("秋风瑟瑟，叶落纷纷");
    expect(aggregatedResponse.text()).to.include("家人围坐在一起");
  });
  it("streaming response - functioncall", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-success-function-call-short.txt",
    );
    const result = processStream(fakeResponse as Response);
    for await (const response of result.stream) {
      expect(response.text()).to.be.empty;
      expect(response.functionCall()).to.be.deep.equal({
        name: "getTemperature",
        args: { city: "San Jose" },
      });
    }
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.be.empty;
    expect(aggregatedResponse.functionCall()).to.be.deep.equal({
      name: "getTemperature",
      args: { city: "San Jose" },
    });
  });

  it("streaming response - searchGrounding", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-success-search-grounding.txt",
    );
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    const expectedGroundingMetadata = {
      searchEntryPoint: {
        renderedContent: "test_rendered_content",
      },
      groundingChunks: [
        {
          web: {
            uri: "test_uri_1",
            title: "test_title_1",
          },
        },
        {
          web: {
            uri: "test_uri_2",
            title: "test_title_2",
          },
        },
      ],
      groundingSupports: [
        {
          segment: {
            endIndex: 41,
            text: "The current stock price for Alphabet Inc.",
          },
          groundingChunkIndices: [0],
          confidenceScores: [0.68925554],
        },
        {
          segment: {
            startIndex: 42,
            endIndex: 82,
            text: "(Google) Class C (GOOG) is \\$166.79 USD.",
          },
          groundingChunkIndices: [1, 0],
          confidenceScores: [0.92251855, 0.92251855],
        },
        {
          segment: {
            startIndex: 83,
            endIndex: 147,
            text: "This price reflects a decrease of -0.97% over the last 24 hours.",
          },
          groundingChunkIndices: [0],
          confidenceScores: [0.9831334],
        },
        {
          segment: {
            startIndex: 150,
            endIndex: 199,
            text: "Please note that stock prices can change rapidly.",
          },
          groundingChunkIndices: [1],
          confidenceScores: [0.6181941],
        },
        {
          segment: {
            startIndex: 201,
            endIndex: 267,
            text: "This information is current as of October 2, 2024, at 5:58 PM UTC.",
          },
          groundingChunkIndices: [1],
          confidenceScores: [0.6107691],
        },
      ],
      webSearchQueries: ["what is the current google stock price"],
    };
    expect(aggregatedResponse.text()).to.include("$166.79 USD");
    expect(aggregatedResponse.candidates[0].groundingMetadata).to.be.deep.equal(
      expectedGroundingMetadata,
    );
  });

  it("candidate had finishReason", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-failure-finish-reason-safety.txt",
    );
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.candidates?.[0].finishReason).to.equal("SAFETY");
    // `text` is passed unbound so chai can invoke it and observe the throw.
    expect(aggregatedResponse.text).to.throw("SAFETY");
    for await (const response of result.stream) {
      expect(response.text).to.throw("SAFETY");
    }
  });
  it("prompt was blocked", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-failure-prompt-blocked-safety.txt",
    );
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text).to.throw("SAFETY");
    expect(aggregatedResponse.promptFeedback?.blockReason).to.equal("SAFETY");
    for await (const response of result.stream) {
      expect(response.text).to.throw("SAFETY");
    }
  });
  it("empty content", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-failure-empty-content.txt",
    );
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.equal("");
    for await (const response of result.stream) {
      expect(response.text()).to.equal("");
    }
  });
  it("unknown enum - should ignore", async () => {
    const fakeResponse = getMockResponseStreaming("streaming-unknown-enum.txt");
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("Cats");
    await expectNonEmptyChunks(result.stream);
  });
  it("recitation ending with a missing content field", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-failure-recitation-no-content.txt",
    );
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text).to.throw("RECITATION");
    expect(aggregatedResponse.candidates[0].content.parts[0].text).to.include(
      "Copyrighted text goes here",
    );
    for await (const response of result.stream) {
      if (response.candidates[0].finishReason !== FinishReason.RECITATION) {
        expect(response.text()).to.not.be.empty;
      } else {
        expect(response.text).to.throw("RECITATION");
      }
    }
  });
  it("handles citations", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-success-citations.txt",
    );
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("Quantum mechanics is");
    expect(
      aggregatedResponse.candidates[0].citationMetadata.citationSources.length,
    ).to.equal(2);
    let foundCitationMetadata = false;
    for await (const response of result.stream) {
      expect(response.text()).to.not.be.empty;
      if (response.candidates[0].citationMetadata) {
        foundCitationMetadata = true;
      }
    }
    expect(foundCitationMetadata).to.be.true;
  });
});

describe("aggregateResponses", () => {
  it("handles no candidates, and promptFeedback", () => {
    // A single chunk that carries only prompt feedback (no candidates).
    const feedbackOnly: GenerateContentResponse[] = [
      {
        promptFeedback: {
          blockReason: BlockReason.SAFETY,
          safetyRatings: [
            {
              category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
              probability: HarmProbability.LOW,
            },
          ],
        },
      },
    ];
    const merged = aggregateResponses(feedbackOnly);
    expect(merged.candidates).to.not.exist;
    expect(merged.promptFeedback.blockReason).to.equal(BlockReason.SAFETY);
  });

  describe("multiple responses, has candidates", () => {
    // Shared aggregation result, computed once for every assertion below.
    let merged: GenerateContentResponse;

    before(() => {
      // Three streamed chunks for candidate 0. Each one differs in text,
      // finish data, safety ratings and citations, so the tests can tell
      // which chunk each aggregated field came from.
      const streamedChunks: GenerateContentResponse[] = [
        {
          candidates: [
            {
              index: 0,
              content: {
                role: "user",
                parts: [{ text: "hello." }],
              },
              finishReason: FinishReason.STOP,
              finishMessage: "something",
              safetyRatings: [
                {
                  category: HarmCategory.HARM_CATEGORY_HARASSMENT,
                  probability: HarmProbability.NEGLIGIBLE,
                },
              ],
            },
          ],
          promptFeedback: {
            blockReason: BlockReason.SAFETY,
            safetyRatings: [
              {
                category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                probability: HarmProbability.LOW,
              },
            ],
          },
        },
        {
          candidates: [
            {
              index: 0,
              content: {
                role: "user",
                parts: [{ text: "angry stuff" }],
              },
              finishReason: FinishReason.STOP,
              finishMessage: "something",
              safetyRatings: [
                {
                  category: HarmCategory.HARM_CATEGORY_HARASSMENT,
                  probability: HarmProbability.NEGLIGIBLE,
                },
              ],
              citationMetadata: {
                citationSources: [
                  {
                    startIndex: 0,
                    endIndex: 20,
                    uri: "sourceurl",
                    license: "",
                  },
                ],
              },
            },
          ],
          promptFeedback: {
            blockReason: BlockReason.OTHER,
            safetyRatings: [
              {
                category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                probability: HarmProbability.HIGH,
              },
            ],
          },
        },
        {
          candidates: [
            {
              index: 0,
              content: {
                role: "user",
                parts: [{ text: "...more stuff" }],
              },
              finishReason: FinishReason.MAX_TOKENS,
              finishMessage: "too many tokens",
              safetyRatings: [
                {
                  category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                  probability: HarmProbability.MEDIUM,
                },
              ],
              citationMetadata: {
                citationSources: [
                  {
                    startIndex: 0,
                    endIndex: 20,
                    uri: "sourceurl",
                    license: "",
                  },
                  {
                    startIndex: 150,
                    endIndex: 155,
                    uri: "sourceurl",
                    license: "",
                  },
                ],
              },
            },
          ],
          promptFeedback: {
            blockReason: BlockReason.OTHER,
            safetyRatings: [
              {
                category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                probability: HarmProbability.HIGH,
              },
            ],
          },
        },
      ];
      merged = aggregateResponses(streamedChunks);
    });

    it("aggregates text across responses", () => {
      const [onlyCandidate] = merged.candidates;
      expect(merged.candidates.length).to.equal(1);
      const texts = onlyCandidate.content.parts.map((part) => part.text);
      expect(texts).to.deep.equal(["hello.", "angry stuff", "...more stuff"]);
    });

    it("takes the last response's promptFeedback", () => {
      const { blockReason } = merged.promptFeedback;
      expect(blockReason).to.equal(BlockReason.OTHER);
    });

    it("takes the last response's finishReason", () => {
      const { finishReason } = merged.candidates[0];
      expect(finishReason).to.equal(FinishReason.MAX_TOKENS);
    });

    it("takes the last response's finishMessage", () => {
      const { finishMessage } = merged.candidates[0];
      expect(finishMessage).to.equal("too many tokens");
    });

    it("takes the last response's candidate safetyRatings", () => {
      const [firstRating] = merged.candidates[0].safetyRatings;
      expect(firstRating.category).to.equal(
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
      );
      expect(firstRating.probability).to.equal(HarmProbability.MEDIUM);
    });

    it("collects all citationSources into one array", () => {
      // NOTE(review): the final chunk already carries both expected sources,
      // and its first source duplicates the earlier chunk's only source. This
      // passes whether aggregation keeps the last citationMetadata or merges
      // and de-duplicates. Confirm which behavior is intended.
      const starts = merged.candidates[0].citationMetadata.citationSources.map(
        (source) => source.startIndex,
      );
      expect(starts).to.deep.equal([0, 150]);
    });
  });
});
